{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Search White House Speeches from 2021 to 2022 Based On Content\n",
    "\n",
    "This notebook builds a semantic search over White House speeches from 2021 and 2022, many of which were given after GPT-3.5 was trained. The White House (Speeches and Remarks) 12/10/2022 dataset can be found on [Kaggle](https://www.kaggle.com/datasets/mohamedkhaledelsafty/the-white-house-speeches-and-remarks-12102022); for this example, we've also made it available on Google Drive. We embed each speech with the sentence-transformers library and store the embeddings in a vector database, using [Milvus Lite](https://milvus.io/docs/milvus_lite.md) to run that database locally.\n",
    "\n",
    "We begin by installing the necessary libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install pymilvus sentence-transformers gdown milvus"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download Dataset\n",
    "\n",
    "Next, we download and extract our dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import zipfile\n",
    "\n",
    "import gdown\n",
    "\n",
    "url = 'https://drive.google.com/uc?id=10_sVL0UmEog7mczLedK5s1pnlDOz3Ukf'\n",
    "output = './white_house_2021_2022.zip'\n",
    "gdown.download(url, output)\n",
    "\n",
    "# Extract the downloaded archive into its own directory\n",
    "with zipfile.ZipFile(output, \"r\") as zip_ref:\n",
    "    zip_ref.extractall(\"./white_house_2021_2022\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean the Data\n",
    "\n",
    "The dataset is not pre-cleaned, so we need to clean it before we can work with it. First, we drop all rows with any `Null` or `NaN` values using `.dropna()`. Next, we filter out partial speeches by keeping only speeches longer than 50 characters. We also remove the carriage return and newline sequences (`\\r\\n`) from the speeches. Finally, we parse the dates into pandas datetime objects and also store them as Unix timestamps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "df = pd.read_csv(\"./white_house_2021_2022/The white house speeches.csv\")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.dropna()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the filtered rows so that later column assignments do not write to a view of df\n",
    "cleaned_df = df.loc[df[\"Speech\"].str.len() > 50].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cleaned_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
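    "# Remove literal carriage return + newline sequences (plain string match, not a regex)\n",
    "cleaned_df[\"Speech\"] = cleaned_df[\"Speech\"].str.replace(\"\\r\\n\", \"\", regex=False)\n",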
    "cleaned_df.iloc[0][\"Speech\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the 'Date_time' column to datetime objects\n",
    "cleaned_df[\"Date_time\"] = pd.to_datetime(cleaned_df[\"Date_time\"], format=\"%B %d, %Y\")\n",
    "\n",
    "# Convert the datetime objects to Unix time format\n",
    "cleaned_df[\"unix_time\"] = cleaned_df[\"Date_time\"].apply(lambda x: int(x.timestamp()))\n",
    "\n",
    "cleaned_df"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Establish a Vector Database and Schema\n",
    "\n",
    "With the data cleaning done, it's time to set up our vector database, Milvus Lite. We declare some constants, then start a server and establish a connection to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "COLLECTION_NAME = \"white_house_2021_2022\"\n",
    "DIMENSION = 384\n",
    "BATCH_SIZE = 128\n",
    "TOPK = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from milvus import default_server\n",
    "from pymilvus import connections, utility\n",
    "\n",
    "default_server.start()\n",
    "connections.connect(host=\"127.0.0.1\", port=default_server.listen_port)\n",
    "\n",
    "utility.get_server_version()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just to make sure that we are starting from a blank slate, we check for the existence of any collection with the same name as the one we chose and drop it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if utility.has_collection(COLLECTION_NAME):\n",
    "    utility.drop_collection(COLLECTION_NAME)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we establish our schema. This dataset has four attributes to work with: the title of the speech, the date the speech was given, the location where the speech was given, and the speech itself. We want to perform a semantic search on the content of the speech, so the schema contains the title, the date, the location, and a vector embedding of the speech, along with an auto-generated `id` primary key.\n",
    "\n",
    "Each `VARCHAR` (string) field needs a max length. None of the values in this dataset reach these limits; they serve as rough upper bounds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pymilvus import FieldSchema, CollectionSchema, DataType, Collection\n",
    "\n",
    "# object should be inserted in the format of (title, date, location, speech embedding)\n",
    "fields = [\n",
    "    FieldSchema(name=\"id\", dtype=DataType.INT64, is_primary=True, auto_id=True),\n",
    "    FieldSchema(name=\"title\", dtype=DataType.VARCHAR, max_length=500),\n",
    "    FieldSchema(name=\"date\", dtype=DataType.VARCHAR, max_length=100),\n",
    "    FieldSchema(name=\"location\", dtype=DataType.VARCHAR, max_length=200),\n",
    "    FieldSchema(name=\"embedding\", dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)\n",
    "]\n",
    "schema = CollectionSchema(fields=fields)\n",
    "collection = Collection(name=COLLECTION_NAME, schema=schema)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the vector database server running and the collection and schema established, the last step before inserting vectors is to create the vector index. For this example, we use an `IVF_FLAT` index with the `L2` distance metric and 128 clusters (`nlist`), then load the collection into memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index_params = {\n",
    "    \"index_type\": \"IVF_FLAT\",\n",
    "    \"metric_type\": \"L2\",\n",
    "    \"params\": {\"nlist\": 128},\n",
    "}\n",
    "collection.create_index(field_name=\"embedding\", index_params=index_params)\n",
    "collection.load()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get Vector Embeddings and Populate the Database\n",
    "\n",
    "Here we use the `sentence-transformers` library to get vector embeddings for the speeches and populate our Milvus instance with them. For this example, we use the [MiniLM L6 v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model, which produces 384-dimensional embeddings, matching the `DIMENSION` constant defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "transformer = SentenceTransformer('all-MiniLM-L6-v2')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We create a function, `embed_insert`, that gets the embeddings for a batch of speeches, and then inserts that batch into our Milvus instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expects a list of four lists: [titles, dates, locations, speeches]\n",
    "def embed_insert(data: list):\n",
    "    embeddings = transformer.encode(data[3])\n",
    "    ins = [\n",
    "        data[0],\n",
    "        data[1],\n",
    "        data[2],\n",
    "        [x for x in embeddings]\n",
    "    ]\n",
    "    collection.insert(ins)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With our helper function written, we are ready to embed and insert the text. First, we convert our `pandas` dataframe into the format the insert expects: a list of four lists, holding the titles, dates, locations, and speeches respectively. We fill these lists row by row and call `embed_insert` each time they reach `BATCH_SIZE` entries, then once more for any remainder. Finally, when all of the data has been inserted, we `flush` the collection so that any unsealed segments are sealed and indexed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_batch = [[], [], [], []]\n",
    "\n",
    "for index, row in cleaned_df.iterrows():\n",
    "    data_batch[0].append(row[\"Title\"])\n",
    "    data_batch[1].append(str(row[\"Date_time\"]))\n",
    "    data_batch[2].append(row[\"Location\"])\n",
    "    data_batch[3].append(row[\"Speech\"])\n",
    "    if len(data_batch[0]) % BATCH_SIZE == 0:\n",
    "        embed_insert(data_batch)\n",
    "        data_batch = [[], [], [], []]\n",
    "\n",
    "# Embed and insert the remainder\n",
    "if len(data_batch[0]) != 0:\n",
    "    embed_insert(data_batch)\n",
    "\n",
    "# Call a flush to index any unsealed segments.\n",
    "collection.flush()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run a Semantic Search\n",
    "\n",
    "With the database populated, we can now search the speeches based on their content. In this example, we search for a speech where the President speaks about renewable energy at NREL, and a speech where the Vice President and the Prime Minister of Canada both speak. We get the embeddings for these descriptions, and then search our vector database for the 3 speeches with the closest embeddings (the smallest `L2` distances).\n",
    "\n",
    "We expect the first description to have the speech titled \"Remarks by President Biden During a Tour of the National Renewable Energy Laboratory\" in its results and the second description to have the speech titled \"REMARKS BY VICE PRESIDENT HARRIS AND PRIME MINISTER TRUDEAU OF CANADA BEFORE BILATERAL MEETING\" in its results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "search_terms = [\n",
    "    \"The President speaks about the impact of renewable energy at the National Renewable Energy Lab.\",\n",
    "    \"The Vice President and the Prime Minister of Canada both speak.\",\n",
    "]\n",
    "\n",
    "# Embed the input text so it can be compared against the stored speech embeddings\n",
    "def embed_search(data):\n",
    "    embeds = transformer.encode(data)\n",
    "    return [x for x in embeds]\n",
    "\n",
    "search_data = embed_search(search_terms)\n",
    "\n",
    "start = time.time()\n",
    "res = collection.search(\n",
    "    data=search_data,  # Embedded search values\n",
    "    anns_field=\"embedding\",  # Search across embeddings\n",
    "    param={\"metric_type\": \"L2\",\n",
    "            \"params\": {\"nprobe\": 10}},\n",
    "    limit = TOPK,  # Limit to top_k results per search\n",
    "    output_fields=[\"title\"]  # Include title field in result\n",
    ")\n",
    "end = time.time()\n",
    "\n",
    "for hits_i, hits in enumerate(res):\n",
    "    print(\"Query:\", search_terms[hits_i])\n",
    "    print(\"Search Time (all queries):\", end - start)\n",
    "    print(\"Results:\")\n",
    "    for hit in hits:\n",
    "        print( hit.entity.get(\"title\"), \"----\", hit.distance)\n",
    "    print()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we stop the Milvus Lite server."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "default_server.stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hw_milvus",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
